package cool.houge.ws.impl;

import static io.netty.handler.codec.http.HttpHeaderNames.AUTHORIZATION;

import com.typesafe.config.Config;
import com.typesafe.config.ConfigException;
import cool.houge.ws.AuthHandler;
import cool.houge.ws.BadAuthorizationException;
import cool.houge.ws.ConfigKeys;
import cool.houge.ws.jjwt.JsonbDeserializer;
import cool.houge.ws.session.DefaultSession;
import cool.houge.ws.session.Session;
import io.avaje.jsonb.Jsonb;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import io.jsonwebtoken.impl.DefaultJwtParserBuilder;
import io.jsonwebtoken.security.Keys;
import io.netty.handler.codec.http.QueryStringDecoder;
import jakarta.inject.Singleton;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import reactor.core.publisher.Mono;
import reactor.netty.http.server.HttpServerRequest;
import reactor.netty.http.websocket.WebsocketInbound;
import reactor.netty.http.websocket.WebsocketOutbound;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

/**
 * {@link AuthHandler} that authenticates WebSocket connections with an HMAC-signed JWT.
 *
 * <p>The token is taken from an {@code Authorization: Bearer <token>} header, falling back to the
 * {@code access_token} query parameter when the header is absent. The JWS header's {@code kid}
 * selects the signing secret from the {@link ConfigKeys#JWT_SECRETS} section of the configuration.
 *
 * @author ZY (kzou227@qq.com)
 */
@Singleton
public class JwtAuthHandler implements AuthHandler {

    private static final Logger log = LogManager.getLogger(JwtAuthHandler.class);

    private static final String ACCESS_TOKEN_QUERY = "access_token";
    private static final String AUTHORIZATION_SCHEME = "Bearer";

    /** Immutable, thread-safe parser built once and shared by every connection. */
    private final JwtParser parser;

    public JwtAuthHandler(Config config) {
        this.parser = buildParser(config);
    }

    /**
     * Authenticates the connection and builds its session.
     *
     * @param inbound the WebSocket inbound side of the connection
     * @param outbound the WebSocket outbound side of the connection
     * @return the session for the authenticated user; empty when the request carries no token; an
     *     error when the token is malformed or fails verification
     */
    @Override
    public Mono<Session> handle(WebsocketInbound inbound, WebsocketOutbound outbound) {
        return getToken(inbound)
                .flatMap(this::resolve)
                .map(t -> new DefaultSession(inbound, outbound, t.getT1(), t.getT2()));
    }

    /**
     * Verifies the JWS and extracts the user id.
     *
     * @param token the compact JWS string
     * @return a tuple of (subject / uid, original token), or an error signal if the token is invalid
     *     or lacks a {@code sub} claim
     */
    Mono<Tuple2<String, String>> resolve(String token) {
        // fromCallable turns parse failures into an onError signal instead of a synchronous throw.
        return Mono.fromCallable(() -> {
            var claims = parser.parseClaimsJws(token).getBody();
            var uid = claims.getSubject();
            if (uid == null || uid.isEmpty()) {
                throw new JwtException("访问令牌缺少 sub");
            }
            log.info("访问令牌解析成功 uid={} jid={}", uid, claims.getId());
            return Tuples.of(uid, token);
        });
    }

    /**
     * Extracts the access token from the request.
     *
     * @param inbound the WebSocket inbound side of the connection
     * @return the token from the {@code Authorization} header if present, otherwise the first
     *     {@code access_token} query parameter, otherwise empty; an error if the header is present
     *     but is not a well-formed Bearer credential
     */
    Mono<String> getToken(WebsocketInbound inbound) {
        return Mono.justOrEmpty(inbound.headers().get(AUTHORIZATION))
                .flatMap(JwtAuthHandler::parseBearerToken)
                .switchIfEmpty(Mono.defer(() -> getQueryToken(inbound)));
    }

    /** Parses a {@code Bearer <token>} Authorization header value into its token. */
    private static Mono<String> parseBearerToken(String authorization) {
        var value = authorization.trim();
        var sep = value.indexOf(' ');
        if (sep > 0) {
            var scheme = value.substring(0, sep);
            // RFC 7235 §2.1: the scheme is case-insensitive and may be followed by 1*SP.
            var credentials = value.substring(sep + 1).trim();
            if (AUTHORIZATION_SCHEME.equalsIgnoreCase(scheme)
                    && !credentials.isEmpty()
                    && credentials.indexOf(' ') < 0) {
                return Mono.just(credentials);
            }
        }
        return Mono.error(new BadAuthorizationException("非法的AUTHORIZATION：" + authorization));
    }

    /** Reads the first {@code access_token} query parameter, or completes empty if absent. */
    private static Mono<String> getQueryToken(WebsocketInbound inbound) {
        // NOTE(review): assumes the inbound object handed to WebSocket handlers also implements
        // HttpServerRequest (a ClassCastException otherwise) — TODO confirm against reactor-netty.
        var request = (HttpServerRequest) inbound;
        var values = new QueryStringDecoder(request.uri()).parameters().get(ACCESS_TOKEN_QUERY);
        if (values == null || values.isEmpty()) {
            return Mono.empty();
        }
        return Mono.just(values.get(0));
    }

    /** Builds the shared JWT parser; signing keys are still looked up per token by {@code kid}. */
    private static JwtParser buildParser(Config config) {
        return new DefaultJwtParserBuilder()
                .setSigningKeyResolver(new SigningKeyResolverAdapter() {
                    @Override
                    public Key resolveSigningKey(JwsHeader header, Claims claims) {
                        return lookupSigningKey(config, header.getKeyId());
                    }
                })
                .deserializeJsonWith(new JsonbDeserializer(Jsonb.builder().build()))
                .build();
    }

    /**
     * Resolves the HMAC key configured for {@code kid} under {@link ConfigKeys#JWT_SECRETS}.
     *
     * @throws JwtException if {@code kid} is missing or no usable secret is configured for it
     */
    private static Key lookupSigningKey(Config config, String kid) {
        if (kid == null || kid.isEmpty()) {
            throw new JwtException("访问令牌缺少 kid");
        }
        // A missing JWT_SECRETS section is a server misconfiguration, so it is not remapped here.
        var secrets = config.getConfig(ConfigKeys.JWT_SECRETS);
        String secret;
        try {
            secret = secrets.getString(kid);
        } catch (ConfigException e) {
            // Unknown kid, malformed path expression, or non-string value: reject the token.
            throw new JwtException("未知的 kid：" + kid, e);
        }
        return Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8));
    }
}
